# Copyright 2021-2023 @ Shenzhen Bay Laboratory &
#                       Peking University &
#                       Huawei Technologies Co., Ltd
#
# This code is a part of MindSPONGE:
# MindSpore Simulation Package tOwards Next Generation molecular modelling.
#
# MindSPONGE is open-source software based on the AI-framework:
# PyTorch (https://pytorch.org/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Free energy estimator cell.
"""

from typing import Union, List
import torch
from torch import Tensor, nn
from torch.nn import Parameter

from .energy import WithEnergyCell
from ...partition import NeighbourList
from ...system import Molecule
from ...potential import PotentialCell
from ...potential.bias import Bias
from ...sampling.wrapper import EnergyWrapper
from ...function import get_tensor, keepdims_mean


class FreeEnergyEstimator(WithEnergyCell):
    r"""Wrapper Cell for free energy estimator, which is a subclass of `WithEnergyCell`

    On top of the potential and bias handling of `WithEnergyCell`, this cell evaluates
    one or more kernel functions and combines them with the wrapped energy scaled by
    :math:`1 / (k_B T)`.

    Args:

        system (Molecule): Simulation system.

        potential (PotentialCell): Potential energy function cell.

        kernel (Union[Bias, List[Bias]]): Kernel function cell.

        temperature (float): Simulation temperature.

        bias (Union[Bias, List[Bias]]): Bias potential function cell. Default: None

        cutoff (float): Cut-off distance for neighbour list. If None is given, it will be assigned
            as the cutoff value of the potential energy. Default: None

        neighbour_list (NeighbourList): Neighbour list. Default: None

        wrapper (EnergyWrapper): Network to wrap and process potential and bias. Default: None

    Raises:

        TypeError: If `kernel` is neither a list nor an `nn.Module`.

    Supported Platforms:

        ``CPU`` ``GPU``

    Symbols:

        B:  Batchsize, i.e. number of walkers of the simulation.

        A:  Number of the atoms in the simulation system.

        N:  Number of the maximum neighbouring atoms.

        U:  Number of potential energy terms.

        V:  Number of bias potential terms.

        K:  Number of kernel functions.

    """

    def __init__(self,
                 system: Molecule,
                 potential: PotentialCell,
                 kernel: Union[Bias, List[Bias]],
                 temperature: float,
                 bias: Union[Bias, List[Bias]] = None,
                 cutoff: float = None,
                 neighbour_list: NeighbourList = None,
                 wrapper: EnergyWrapper = None,
                 ):

        super().__init__(
            system=system,
            potential=potential,
            bias=bias,
            cutoff=cutoff,
            neighbour_list=neighbour_list,
            wrapper=wrapper,
        )

        # Normalise `kernel` to a ModuleList so that a single kernel and a list
        # of kernels are handled by the same code path in `forward`.
        self.kernel_function: List[Bias] = None
        if isinstance(kernel, list):
            self.kernel_function = nn.ModuleList(kernel)
        elif isinstance(kernel, nn.Module):
            self.kernel_function = nn.ModuleList([kernel])
        else:
            raise TypeError(f'kernel must be nn.Module or list but got: {type(kernel)}')

        self._num_kernels = len(self.kernel_function)
        self._kernel_names = [kernel_fn.name for kernel_fn in self.kernel_function]

        # Stored as a non-trainable Parameter; updated through `set_temperature`.
        temperature = get_tensor(temperature, torch.float32)
        self.temperature = Parameter(temperature, requires_grad=False)

        self.boltzmann = self.units.boltzmann

        # NOTE(review): these two tensors are allocated here but are not written
        # anywhere in this class - confirm whether `forward` is meant to fill them.
        self._kernels = get_tensor(torch.zeros((self.num_walker, self._num_kernels), dtype=torch.float32))
        self._kernel = get_tensor(torch.zeros((self.num_walker, 1), dtype=torch.float32))

    def set_temperature(self, temperature: Union[float, Tensor]) -> Tensor:
        """Set the simulation temperature.

        Args:
            temperature (Union[float, Tensor]): New temperature. Its value is copied in place
                into the existing `temperature` parameter.

        Return:
            temperature (Tensor): The updated temperature parameter.

        """
        # Accept plain numbers as well as tensors, consistent with `__init__`.
        temperature = torch.as_tensor(
            temperature, dtype=self.temperature.dtype, device=self.temperature.device)
        with torch.no_grad():
            self.temperature.copy_(temperature)
        return self.temperature

    def forward(self, *inputs) -> Tensor:
        r"""Calculate the walker-averaged sum of the kernel and the reduced energy.

        The potential energies and (optional) bias potentials are evaluated and passed through
        the energy wrapper; the wrapped energy is multiplied by :math:`1 / (k_B T)` and added
        to the concatenated kernel values.

        Return:
            vfe (Tensor):   Result of `keepdims_mean` along axis 0 (the walker axis) applied to
                            :math:`\mathrm{kernel} + \beta E(R)`.
                            Presumably of shape `(1, K)` - TODO confirm against `keepdims_mean`.
                            Data type is float.

        Symbols:
            B:  Batchsize, i.e. number of walkers of the simulation.
            K:  Number of kernel functions.

        """

        #pylint: disable=unused-argument

        coordinate, pbc_box = self.system()

        neigh_idx, neigh_vec, neigh_dis, neigh_mask = self.neighbour_list(coordinate, pbc_box)

        # Convert lengths to the unit of the potential. The multiplications are
        # out-of-place on purpose: an in-place `*=` would modify the tensors owned
        # by the system / neighbour list and is rejected by autograd for leaf
        # tensors that require grad.
        scale = self.length_unit_scale
        coordinate = coordinate * scale
        if pbc_box is not None:
            pbc_box = pbc_box * scale

        if neigh_idx is not None:
            neigh_vec = neigh_vec * scale
            neigh_dis = neigh_dis * scale

        # Potential, bias and kernel cells are all called with the same keyword arguments.
        cell_inputs = dict(
            coordinate=coordinate,
            neighbour_index=neigh_idx,
            neighbour_mask=neigh_mask,
            neighbour_vector=neigh_vec,
            neighbour_distance=neigh_dis,
            pbc_box=pbc_box,
        )

        energies = self.potential_function(**cell_inputs)

        # Record the latest potential energy terms without tracking gradients.
        with torch.no_grad():
            self._energies.copy_(energies)

        biases = None
        if self.bias_function is not None:
            biases = torch.cat(
                [self.bias_function[i](**cell_inputs) for i in range(self._num_biases)], dim=-1)
            with torch.no_grad():
                self._biases.copy_(biases)

        energy, bias = self.energy_wrapper(energies, biases)

        if self.bias_function is not None:
            with torch.no_grad():
                self._bias.copy_(bias)

        # Concatenate the outputs of all kernels along the last axis.
        kernel = torch.cat([kernel_fn(**cell_inputs) for kernel_fn in self.kernel_function], dim=-1)

        # kernel + E / (k_B * T); `energy` is broadcast against the kernel columns.
        vfe = kernel + energy * torch.reciprocal(self.boltzmann * self.temperature)

        # Average over axis 0 (walkers), keeping the reduced dimension.
        return keepdims_mean(vfe, 0)
